{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "45ff14a2",
   "metadata": {},
   "source": [
    "## 3D Visualization\n",
    "\n",
    "3D Grad-CAM visualization module."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7fa11b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'\n",
    "\n",
    "import monai\n",
    "from glob import glob\n",
    "import matplotlib.pyplot as plt\n",
    "from onekey_algo import OnekeyDS as okds\n",
    "\n",
    "mydir = 'your_own_directory'  # folder containing the .nii / .nii.gz images\n",
    "samples = [os.path.join(mydir, f) for f in os.listdir(mydir) if f.endswith('.nii') or f.endswith('.nii.gz')]\n",
    "\n",
    "# samples = [samples[-1]]\n",
    "samples"
   ]
  },
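  {
   "cell_type": "markdown",
   "id": "a23129e7",
   "metadata": {},
   "source": [
    "## Choose the model to visualize\n",
    "\n",
    "Use the layer name as a keyword to select which layer to extract for visualization.\n",
    "\n",
    "### Supported model names\n",
    "\n",
    "Set the `model_name` variable in the code to one of the model names below.\n",
    "\n",
    "| **Model family** | **Model name**                                             |\n",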
    "| ------------ | ------------------------------------------------------------ |\n",
    "| ResNet       | resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200 |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f9a9194",
   "metadata": {},
   "outputs": [],
   "source": [
    "from monai.data import ImageDataset\n",
    "from torch.utils.data import DataLoader\n",
    "from onekey_algo.custom.components.comp2 import extract, init_from_onekey3d\n",
    "\n",
    "viz_dir = \"viz directory of your own trained 3D model\"\n",
    "model, transformer, device = init_from_onekey3d(viz_dir)\n",
    "\n",
    "for n, m in model.named_modules():\n",
    "    print('Feature name:', n, \"|| Module:\", m)"
   ]
  },
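  {
   "cell_type": "markdown",
   "id": "e46336a1",
   "metadata": {},
   "source": [
    "## Visualize a convolutional layer\n",
    "\n",
    "The name printed after `Feature name:` is the layer to visualize, for example `layer4.2.conv3`. This is usually the last convolutional layer of the deep learning feature extractor.\n",
    "\n",
    "**Note**: the layer to visualize must be a convolutional layer whose name contains `conv`, and it is usually the last one."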
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9d9ad6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Must match one of the names printed after 'Feature name:' in the previous code cell\n",
    "target_layer = \"layer4.2.conv2\"\n",
    "gradcam = monai.visualize.GradCAM(nn_module=model, target_layers=target_layer)\n",
    "\n",
    "val_ds = ImageDataset(image_files=samples, transform=transformer)\n",
    "# create a validation data loader\n",
    "val_loader = DataLoader(val_ds, batch_size=1, num_workers=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39c9708b",
   "metadata": {},
   "source": [
    "## Display the visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9eb2b98",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.comp2 import show_cam_on_image\n",
    "import torch\n",
    "from PIL import Image\n",
    "\n",
    "for sample_ in val_loader:\n",
    "    # class_idx=None lets GradCAM use the predicted (argmax) class\n",
    "    res_cam = gradcam(x=sample_.to(device), class_idx=None)\n",
    "    sample_np = sample_.cpu().detach().numpy()\n",
    "    # Plot each slice along the last axis: original image (left), CAM overlay (right)\n",
    "    for idx in range(sample_.size()[-1]):\n",
    "        fig, axes = plt.subplots(1, 2, figsize=(10, 5), facecolor='white')\n",
    "        axes[0].axis('off')\n",
    "        axes[0].imshow(sample_np[0][0][..., idx], cmap='gray')\n",
    "        img = Image.fromarray(sample_np[0][0][..., idx]*255).convert('RGB')\n",
    "        axes[1].axis('off')\n",
    "        imshow = axes[1].imshow(show_cam_on_image(img.resize(sample_np.shape[-3:-1]), -res_cam[0][0][..., idx].cpu().detach().numpy(),\n",
    "                                                  use_rgb=True, reverse=False), \n",
    "                            cmap='jet')\n",
    "        cax = fig.add_axes([0.92, 0.15, 0.02, axes[0].get_position().height]) \n",
    "        plt.colorbar(imshow, cax=cax)\n",
    "        plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
